{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Customize a TabNet Model\n",
    "\n",
    "## This tutorial gives examples on how to easily customize a TabNet Model\n",
    "\n",
    "### 1 - Customizing your learning rate scheduler\n",
    "\n",
    "Almost all classical pytroch schedulers are now easy to integrate with pytorch-tabnet\n",
    "\n",
    "### 2 - Use your own loss function\n",
    "\n",
    "It's really easy to use any pytorch loss function with TabNet, we'll walk you through that\n",
    "\n",
    "\n",
    "### 3 - Customizing your evaluation metric and evaluations sets\n",
    "\n",
    "Like XGBoost, you can easily monitor different metrics on different evaluation sets with pytorch-tabnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_tabnet.tab_model import TabNetClassifier\n",
    "\n",
    "import torch\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "np.random.seed(0)\n",
    "\n",
    "\n",
    "import os\n",
    "import wget\n",
    "from pathlib import Path\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download census-income dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\"\n",
    "dataset_name = 'census-income'\n",
    "out = Path(os.getcwd()+'/data/'+dataset_name+'.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out.parent.mkdir(parents=True, exist_ok=True)\n",
    "if out.exists():\n",
    "    print(\"File already exists.\")\n",
    "else:\n",
    "    print(\"Downloading file...\")\n",
    "    wget.download(url, out.as_posix())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load data and split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train = pd.read_csv(out)\n",
    "target = ' <=50K'\n",
    "if \"Set\" not in train.columns:\n",
    "    train[\"Set\"] = np.random.choice([\"train\", \"valid\", \"test\"], p =[.8, .1, .1], size=(train.shape[0],))\n",
    "\n",
    "train_indices = train[train.Set==\"train\"].index\n",
    "valid_indices = train[train.Set==\"valid\"].index\n",
    "test_indices = train[train.Set==\"test\"].index"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Simple preprocessing\n",
    "\n",
    "Label encode categorical features and fill empty cells."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nunique = train.nunique()\n",
    "types = train.dtypes\n",
    "\n",
    "categorical_columns = []\n",
    "categorical_dims =  {}\n",
    "for col in train.columns:\n",
    "    if types[col] == 'object' or nunique[col] < 200:\n",
    "        print(col, train[col].nunique())\n",
    "        l_enc = LabelEncoder()\n",
    "        train[col] = train[col].fillna(\"VV_likely\")\n",
    "        train[col] = l_enc.fit_transform(train[col].values)\n",
    "        categorical_columns.append(col)\n",
    "        categorical_dims[col] = len(l_enc.classes_)\n",
    "    else:\n",
    "        train.fillna(train.loc[train_indices, col].mean(), inplace=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define categorical features for categorical embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "unused_feat = ['Set']\n",
    "\n",
    "features = [ col for col in train.columns if col not in unused_feat+[target]] \n",
    "\n",
    "cat_idxs = [ i for i, f in enumerate(features) if f in categorical_columns]\n",
    "\n",
    "cat_dims = [ categorical_dims[f] for i, f in enumerate(features) if f in categorical_columns]\n"
   ]
  },
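  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the two lists above concrete, here is a minimal sketch on a toy feature list. The column names and cardinalities are made up for illustration; only the two list comprehensions mirror the cell above.\n",
    "\n",
    "```python\n",
    "# Hypothetical column order and cardinalities, for illustration only.\n",
    "features = ['age', 'workclass', 'hours_per_week', 'education']\n",
    "categorical_dims = {'workclass': 9, 'education': 16}\n",
    "\n",
    "# Positions of the categorical columns inside the feature matrix.\n",
    "cat_idxs = [i for i, f in enumerate(features) if f in categorical_dims]\n",
    "# Number of distinct values per categorical column, in the same order.\n",
    "cat_dims = [categorical_dims[f] for f in features if f in categorical_dims]\n",
    "\n",
    "print(cat_idxs)  # [1, 3]\n",
    "print(cat_dims)  # [9, 16]\n",
    "```\n",
    "\n",
    "Both lists must follow the column order of `features`, since the embeddings are matched to columns by position."
   ]
  },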
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1 - Customizing your learning rate scheduler\n",
    "\n",
    "TabNetClassifier, TabNetRegressor and TabNetMultiTaskClassifier all takes two arguments:\n",
    "- scheduler_fn : Any torch.optim.lr_scheduler should work\n",
    "- scheduler_params : A dictionnary that contains the parameters of your scheduler (without the optimizer)\n",
    "\n",
    "----\n",
    "NB1 : Some schedulers like torch.optim.lr_scheduler.ReduceLROnPlateau depend on the evolution of a metric, pytorch-tabnet will use the early stopping metric you asked (the last eval_metric, see 2-) to perform the schedulers updates\n",
    "\n",
    "EX1 : \n",
    "```\n",
    "scheduler_fn=torch.optim.lr_scheduler.ReduceLROnPlateau\n",
    "scheduler_params={\"mode\":'max', # max because default eval metric for binary is AUC\n",
    "                 \"factor\":0.1,\n",
    "                 \"patience\":1}\n",
    "```\n",
    "\n",
    "-----\n",
    "NB2 : Some schedulers require updates at batch level, they can be used very easily the only thing to do is to add `is_batch_level` to True in your `scheduler_params`\n",
    "\n",
    "EX2:\n",
    "```\n",
    "scheduler_fn=torch.optim.lr_scheduler.CyclicLR\n",
    "scheduler_params={\"is_batch_level\":True,\n",
    "                  \"base_lr\":1e-3,\n",
    "                  \"max_lr\":1e-2,\n",
    "                  \"step_size_up\":100\n",
    "                  }\n",
    "```\n",
    "\n",
    "-----\n",
    "NB3: Note that you can also customize your optimizer function, any torch optimizer should work"
   ]
  },
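  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How `scheduler_fn` and `scheduler_params` fit together can be sketched with plain PyTorch. This is only an illustration: it assumes the scheduler is built as `scheduler_fn(optimizer, **scheduler_params)`, the single dummy parameter and the StepLR settings are made up, and the loop is not pytorch-tabnet's actual training loop.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Hypothetical stand-in for the parameters of a model.\n",
    "params = [torch.nn.Parameter(torch.zeros(1))]\n",
    "optimizer = torch.optim.Adam(params, lr=2e-2)\n",
    "\n",
    "scheduler_fn = torch.optim.lr_scheduler.StepLR\n",
    "scheduler_params = {'step_size': 2, 'gamma': 0.5}\n",
    "\n",
    "# Assumption: the optimizer is supplied first, then the user's parameters.\n",
    "scheduler = scheduler_fn(optimizer, **scheduler_params)\n",
    "\n",
    "lrs = []\n",
    "for epoch in range(6):\n",
    "    lrs.append(optimizer.param_groups[0]['lr'])\n",
    "    optimizer.step()   # a real loop would compute gradients before this\n",
    "    scheduler.step()   # epoch-level update: one call per epoch\n",
    "\n",
    "print(lrs)  # the learning rate is halved every 2 epochs\n",
    "```\n",
    "\n",
    "This is why `scheduler_params` must not contain the optimizer itself."
   ]
  },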
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network parameters\n",
    "max_epochs = 20 if not os.getenv(\"CI\", False) else 2\n",
    "batch_size = 1024\n",
    "clf = TabNetClassifier(cat_idxs=cat_idxs,\n",
    "                       cat_dims=cat_dims,\n",
    "                       cat_emb_dim=1,\n",
    "                       optimizer_fn=torch.optim.Adam, # Any optimizer works here\n",
    "                       optimizer_params=dict(lr=2e-2),\n",
    "                       scheduler_fn=torch.optim.lr_scheduler.OneCycleLR,\n",
    "                       scheduler_params={\"is_batch_level\":True,\n",
    "                                         \"max_lr\":5e-2,\n",
    "                                         \"steps_per_epoch\":int(train.shape[0] / batch_size)+1,\n",
    "                                         \"epochs\":max_epochs\n",
    "                                          },\n",
    "                       mask_type='entmax', # \"sparsemax\",\n",
    "                      )"
   ]
  },
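  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "OneCycleLR needs to know the total number of scheduler steps in advance, which is why `steps_per_epoch` and `epochs` appear in `scheduler_params` above. Below is a small sketch of the arithmetic, assuming one scheduler step per batch and `drop_last=False`; the sample count is hypothetical.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "n_train_samples = 26000  # hypothetical size of the training split\n",
    "batch_size = 1024\n",
    "max_epochs = 20\n",
    "\n",
    "# With drop_last=False the last, smaller batch is kept, hence the ceiling.\n",
    "steps_per_epoch = math.ceil(n_train_samples / batch_size)\n",
    "total_steps = steps_per_epoch * max_epochs\n",
    "\n",
    "print(steps_per_epoch)  # 26\n",
    "print(total_steps)      # 520\n",
    "\n",
    "# int(n / batch_size) + 1 gives the same value unless n is an exact multiple\n",
    "# of batch_size, in which case it counts one extra step per epoch.\n",
    "print(int(2048 / 1024) + 1, math.ceil(2048 / 1024))  # 3 2\n",
    "```\n",
    "\n",
    "Overestimating the step count, as the cell above does by using the full dataset size, is harmless: the cycle simply does not reach its end. Underestimating is not, because OneCycleLR raises an error once it is stepped past its total number of steps."
   ]
  },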
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = train[features].values[train_indices]\n",
    "y_train = train[target].values[train_indices]\n",
    "\n",
    "X_valid = train[features].values[valid_indices]\n",
    "y_valid = train[target].values[valid_indices]\n",
    "\n",
    "X_test = train[features].values[test_indices]\n",
    "y_test = train[target].values[test_indices]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2 - Use your own loss function\n",
    "\n",
    "The default loss for classification is torch.nn.functional.cross_entropy\n",
    "\n",
    "The default loss for regression is torch.nn.functional.mse_loss\n",
    "\n",
    "Any derivable loss function of the type lambda y_pred, y_true : loss(y_pred, y_true) should work if it uses torch computation (to allow gradients computations).\n",
    "\n",
    "In particular, any pytorch loss function should work.\n",
    "\n",
    "Once your loss is defined simply pass it loss_fn argument when defining your model.\n",
    "\n",
    "/!\\ : One important thing to keep in mind is that when computing the loss for TabNetClassifier and TabNetMultiTaskClassifier you'll need to apply first torch.nn.Softmax() to y_pred as the final model prediction is softmaxed automatically.\n",
    "\n",
    "NB : Tabnet also has an internal loss (the sparsity loss) which is summed to the loss_fn, the importance of the sparsity loss can be mitigated using `lambda_sparse` parameter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def my_loss_fn(y_pred, y_true):\n",
    "    \"\"\"\n",
    "    Dummy example similar to using default torch.nn.functional.cross_entropy\n",
    "    \"\"\"\n",
    "    softmax_pred = torch.nn.Softmax(dim=-1)(y_pred)\n",
    "    logloss = (1-y_true)*torch.log(softmax_pred[:,0])\n",
    "    logloss += y_true*torch.log(softmax_pred[:,1])\n",
    "    return -torch.mean(logloss)"
   ]
  },
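  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the docstring's claim, the sketch below compares the same hand-written log loss with `torch.nn.functional.cross_entropy` on made-up logits and labels. The random inputs are illustrative only, and the function is repeated under a different name so that the snippet runs on its own.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def manual_log_loss(y_pred, y_true):\n",
    "    # y_pred holds raw scores (logits), so softmax is applied first.\n",
    "    softmax_pred = torch.nn.Softmax(dim=-1)(y_pred)\n",
    "    logloss = (1 - y_true) * torch.log(softmax_pred[:, 0])\n",
    "    logloss += y_true * torch.log(softmax_pred[:, 1])\n",
    "    return -torch.mean(logloss)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "logits = torch.randn(8, 2)          # made-up raw scores for 8 samples\n",
    "labels = torch.randint(0, 2, (8,))  # made-up binary targets\n",
    "\n",
    "custom = manual_log_loss(logits, labels)\n",
    "reference = F.cross_entropy(logits, labels)\n",
    "print(torch.allclose(custom, reference, atol=1e-6))  # True\n",
    "```\n",
    "\n",
    "The two values agree because `cross_entropy` applies the softmax and the logarithm internally, while the hand-written version does both steps explicitly."
   ]
  },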
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3 - Customizing your evaluation metric and evaluations sets\n",
    "\n",
    "When calling the `fit` method you can speficy:\n",
    "- eval_set : a list of tuples like (X_valid, y_valid)\n",
    "    Note that the last value of this list will be used for early stopping\n",
    "- eval_name : a list to name each eval set\n",
    "    default will be val_0, val_1 ...\n",
    "- eval_metric : a list of default metrics or custom metrics\n",
    "    Default : \"auc\", \"accuracy\", \"logloss\", \"balanced_accuracy\", \"mse\", \"rmse\"\n",
    "    \n",
    "    \n",
    "NB : If no eval_set is given no early stopping will occure (patience is then ignored) and the weights used will be the last epoch's weights\n",
    "\n",
    "NB2 : If `patience<=0` this will disable early stopping\n",
    "\n",
    "NB3 : Setting `patience` to `max_epochs` ensures that training won't be early stopped, but best weights from the best epochs will be used (instead of the last weight if early stopping is disabled)"
   ]
  },
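  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three notes on `patience` can be illustrated with a small, self-contained sketch. It is a simplified model of early stopping on a metric to maximize, not pytorch-tabnet's implementation; the helper name and the validation scores are made up.\n",
    "\n",
    "```python\n",
    "def run_with_patience(scores, patience):\n",
    "    '''Return (best_epoch, last_epoch_run) for a metric to maximize.'''\n",
    "    best_score, best_epoch, wait = float('-inf'), -1, 0\n",
    "    for epoch, score in enumerate(scores):\n",
    "        if score > best_score:\n",
    "            best_score, best_epoch, wait = score, epoch, 0\n",
    "        else:\n",
    "            wait += 1\n",
    "            # patience <= 0 disables early stopping entirely\n",
    "            if patience > 0 and wait >= patience:\n",
    "                return best_epoch, epoch\n",
    "    return best_epoch, len(scores) - 1\n",
    "\n",
    "val_auc = [0.80, 0.85, 0.84, 0.83, 0.82]  # made-up validation scores\n",
    "\n",
    "print(run_with_patience(val_auc, patience=2))             # (1, 3)\n",
    "print(run_with_patience(val_auc, patience=0))             # (1, 4)\n",
    "print(run_with_patience(val_auc, patience=len(val_auc)))  # (1, 4)\n",
    "```\n",
    "\n",
    "In the last two cases training runs to the end. What differs, per NB2 and NB3, is which weights are kept: the last epoch's (index 4) when early stopping is disabled, and the best epoch's (index 1) when `patience` equals `max_epochs`."
   ]
  },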
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_tabnet.metrics import Metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class my_metric(Metric):\n",
    "    \"\"\"\n",
    "    2xAUC.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self._name = \"custom\" # write an understandable name here\n",
    "        self._maximize = True\n",
    "\n",
    "    def __call__(self, y_true, y_score):\n",
    "        \"\"\"\n",
    "        Compute AUC of predictions.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        y_true: np.ndarray\n",
    "            Target matrix or vector\n",
    "        y_score: np.ndarray\n",
    "            Score matrix or vector\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "            float\n",
    "            AUC of predictions vs targets.\n",
    "        \"\"\"\n",
    "        return 2*roc_auc_score(y_true, y_score[:, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf.fit(\n",
    "    X_train=X_train, y_train=y_train,\n",
    "    eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
    "    eval_name=['train', 'val'],\n",
    "    eval_metric=[\"auc\", my_metric],\n",
    "    max_epochs=max_epochs , patience=0,\n",
    "    batch_size=batch_size,\n",
    "    virtual_batch_size=128,\n",
    "    num_workers=0,\n",
    "    weights=1,\n",
    "    drop_last=False,\n",
    "    loss_fn=my_loss_fn\n",
    ") \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot losses\n",
    "plt.plot(clf.history['loss'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot auc\n",
    "plt.plot(clf.history['train_auc'])\n",
    "plt.plot(clf.history['val_auc'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot learning rates\n",
    "plt.plot(clf.history['lr'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = clf.predict_proba(X_test)\n",
    "test_auc = roc_auc_score(y_score=preds[:,1], y_true=y_test)\n",
    "\n",
    "preds_valid = clf.predict_proba(X_valid)\n",
    "valid_auc = roc_auc_score(y_score=preds_valid[:,1], y_true=y_valid)\n",
    "\n",
    "\n",
    "print(f\"FINAL VALID SCORE FOR {dataset_name} : {clf.history['val_auc'][-1]}\")\n",
    "print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_auc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# check that last epoch's weight are used\n",
    "assert np.isclose(valid_auc, clf.history['val_auc'][-1], atol=1e-6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save and load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save tabnet model\n",
    "saving_path_name = \"./tabnet_model_test_1\"\n",
    "saved_filepath = clf.save_model(saving_path_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define new model with basic parameters and load state dict weights\n",
    "loaded_clf = TabNetClassifier()\n",
    "loaded_clf.load_model(saved_filepath)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_preds = loaded_clf.predict_proba(X_test)\n",
    "loaded_test_auc = roc_auc_score(y_score=loaded_preds[:,1], y_true=y_test)\n",
    "\n",
    "print(f\"FINAL TEST SCORE FOR {dataset_name} : {loaded_test_auc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert(test_auc == loaded_test_auc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Global explainability : feat importance summing to 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf.feature_importances_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Local explainability and masks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "explain_matrix, masks = clf.explain(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1, 3, figsize=(20,20))\n",
    "\n",
    "for i in range(3):\n",
    "    axs[i].imshow(masks[i][:50])\n",
    "    axs[i].set_title(f\"mask {i}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# XGB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from xgboost import XGBClassifier\n",
    "\n",
    "clf_xgb = XGBClassifier(max_depth=8,\n",
    "    learning_rate=0.1,\n",
    "    n_estimators=1000,\n",
    "    verbosity=0,\n",
    "    silent=None,\n",
    "    objective='binary:logistic',\n",
    "    booster='gbtree',\n",
    "    n_jobs=-1,\n",
    "    nthread=None,\n",
    "    gamma=0,\n",
    "    min_child_weight=1,\n",
    "    max_delta_step=0,\n",
    "    subsample=0.7,\n",
    "    colsample_bytree=1,\n",
    "    colsample_bylevel=1,\n",
    "    colsample_bynode=1,\n",
    "    reg_alpha=0,\n",
    "    reg_lambda=1,\n",
    "    scale_pos_weight=1,\n",
    "    base_score=0.5,\n",
    "    random_state=0,\n",
    "    seed=None,)\n",
    "\n",
    "clf_xgb.fit(X_train, y_train,\n",
    "        eval_set=[(X_valid, y_valid)],\n",
    "        early_stopping_rounds=40,\n",
    "        verbose=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = np.array(clf_xgb.predict_proba(X_valid))\n",
    "valid_auc = roc_auc_score(y_score=preds[:,1], y_true=y_valid)\n",
    "print(valid_auc)\n",
    "\n",
    "preds = np.array(clf_xgb.predict_proba(X_test))\n",
    "test_auc = roc_auc_score(y_score=preds[:,1], y_true=y_test)\n",
    "print(test_auc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
